{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Computes Fairness/Bias metrics\n",
    "Uses the AIF360 AI Fairness 360 toolkit to compute Fairness/Bias metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "!pip3 install scikit-learn==0.24.1 aif360==0.3.0  tensorflow==2.4.0 nodejs==0.1.1 ipywidgets==7.6.3 lime==0.2.0.1 wget==3.2 #aix360==0.2.1 "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import wget\n",
    "wget.download(\n",
    "    'https://raw.githubusercontent.com/'\n",
    "    'elyra-ai/component-library/master/claimed_utils.py'\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from claimed_utils import unzip\n",
    "import pandas as pd\n",
    "from aif360.datasets import BinaryLabelDataset\n",
    "import pickle\n",
    "from aif360.metrics import BinaryLabelDatasetMetric"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# @dependency codait_utils.ipynb\n",
    "# @dependency metadata\n",
    "# @param target_column Column name containing the target/prediction value\n",
    "# (the real measured value)\n",
    "# @param protected_column Protected column (like sex, race, age, ...)\n",
    "# Note: column arrays not supported at the moment\n",
    "# @param prediction_column Column name containing the prediction of the model\n",
    "# @param unpriviledged_group_key value containted in the protected_column\n",
    "# indicating a unpriviledged group (e.g. female)\n",
    "# @param priviledged_group_key value containted in the protected_column\n",
    "# indicating a priviledged group (e.g. male)\n",
    "# @param priviledged_group_key value containted in the protected_column\n",
    "# indicating a priviledged group (e.g. male)\n",
    "# @param metadata csv file name of the data\n",
    "# @param metric_output file name of the pickeled metric object\n",
    "# @returns pickeled metric object"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "target_column = os.environ.get('target_column', 'target')\n",
    "protected_column = os.environ.get('protected_column')\n",
    "prediction_column = os.environ.get('prediction_column', 'prediction')\n",
    "unpriviledged_group_key = os.environ.get('unpriviledged_group_key')\n",
    "priviledged_group_key = os.environ.get('priviledged_group_key')\n",
    "metadata = os.environ.get('metadata', 'metadata.csv')\n",
    "metric_output = os.environ.get('metric_output', 'metric.pickle')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !jupyter labextension install @jupyter-widgets/jupyterlab-manager"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "unzip('.', 'model.zip')\n",
    "unzip('.', 'data.zip')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.read_csv(metadata)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df[\"protected_column_index\"] = df[protected_column].apply(\n",
    "    lambda x:\n",
    "    list(df[protected_column].unique()).index(x)\n",
    ")\n",
    "\n",
    "df[\"missclassified\"] = df.apply(\n",
    "    lambda d:\n",
    "    1 if d[target_column] != d[prediction_column] else 0, axis=1\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "unprivileged_groups = [\n",
    "    {'protected_column_index':\n",
    "     df.loc[\n",
    "         df[protected_column] == unpriviledged_group_key\n",
    "     ].iloc[0]['protected_column_index']}\n",
    "]\n",
    "privileged_groups = [\n",
    "    {'protected_column_index':\n",
    "     df.loc[\n",
    "         df[protected_column] == priviledged_group_key\n",
    "     ].iloc[0]['protected_column_index']}\n",
    "]\n",
    "\n",
    "favorable_label = 0\n",
    "unfavorable_label = 1  # missclassified == True\n",
    "\n",
    "df_for_aif360 = df[[\"protected_column_index\", \"missclassified\"]]\n",
    "\n",
    "df_for_aif360"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "aif360_dataset = BinaryLabelDataset(\n",
    "    favorable_label=favorable_label,\n",
    "    unfavorable_label=unfavorable_label,\n",
    "    df=df_for_aif360,\n",
    "    label_names=['missclassified'],\n",
    "    protected_attribute_names=['protected_column_index'],\n",
    "    unprivileged_protected_attributes=unprivileged_groups)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "metric_orig_train = BinaryLabelDatasetMetric(\n",
    "    aif360_dataset,\n",
    "    unprivileged_groups=unprivileged_groups,\n",
    "    privileged_groups=privileged_groups)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "metric_orig_train.statistical_parity_difference()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "metric_orig_train.smoothed_empirical_differential_fairness()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pickle.dump(metric_orig_train, open(metric_output, \"wb\"))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
